#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"

typedef at::Half bf16;
// typedef at::BFloat16 bf16;

// 元コードには定義が無いが、_N_ が N_ であるとし、N_ = C / H と仮定
// extern定義またはマクロ等で _N_ は定義されているはず。
#ifndef _N_
#define _N_  (C/H)
#endif

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               float *__restrict__ _state,
                               const F *__restrict__ const _r,
                               const F *__restrict__ const _w,
                               const F *__restrict__ const _k,
                               const F *__restrict__ const _v,
                               const F *__restrict__ const _a,
                               const F *__restrict__ const _b,
                               F *__restrict__ const _y)
{
    const int e = blockIdx.x / H;   // e: バッチインデックス
    const int h = blockIdx.x % H;   // h: ヘッドインデックス
    const int i = threadIdx.x;      // i: ベクトル次元内インデックス

    // _stateは (B,H,N_,N_) のような4次元（あるいは (B,H*N_,N_) の2or3次元）想定
    // ここでは (B,H,N_,N_) と解釈し、(e,h,i) に対応するオフセット計算を行う
    // 1ヘッド当たりの状態サイズ = N_*N_
    // h * (N_*N_) でヘッドを進め、
    // i * N_       でiを進める
    // e * (H*N_*N_)でバッチを進める
    _state += e * (H * _N_ * _N_) + h * (_N_ * _N_) + i * _N_;

    float state[_N_];
#pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];

    for (int _t = 0; _t < T; _t++)
    {
        // tは入出力配列(r,w,k,v,a,b,y)にアクセスするためのインデックス
        // (B,T,C)でフラット化していると考えると、
        // e*T*C でバッチを、_t*Cで時間方向を、h*N_ + i でチャネル方向を進める
        const int t = e * T * C + _t * C + h * _N_ + i;

        __syncthreads();
        r[i] = float(_r[t]);
        w[i] = __expf(-__expf(float(_w[t])));
        k[i] = float(_k[t]);
        a[i] = float(_a[t]);
        b[i] = float(_b[t]);
        __syncthreads();

        float sa = 0;
#pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            sa += a[j] * state[j];
        }

        float vv = float(_v[t]);
        float y = 0;
#pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float &s = state[j];
            s = s * w[j] + k[j] * vv + sa * b[j];
            y += s * r[j];
        }
        _y[t] = F(y);
    }

#pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}

void cuda_forward(int B, int T, int C, int H,
                  float *state,
                  bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b,
                  bf16 *y)
{
    assert(H * _N_ == C);

    // B>1 でも動作可能にするため、assertは削除
    // assert(B == 1); // 削除

    // ブロック数 = B*H
    // スレッド数 = _N_
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}
